# Importing modules from the libraries
import numpy as np
import pandas as pd
import os
import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline
plt.style.use('ggplot')
import plotly.offline as py
import plotly.graph_objs as go
from plotly.offline import init_notebook_mode
init_notebook_mode(connected=False)
from wordcloud import WordCloud
from geopy.geocoders import Nominatim
from folium.plugins import HeatMap
import folium
from tqdm import tqdm
import re
from keras.preprocessing.text import Tokenizer
from keras.preprocessing.sequence import pad_sequences
from keras.models import Sequential
from keras.layers import Dense, Embedding, LSTM, SpatialDropout1D
from sklearn.model_selection import train_test_split
from nltk import word_tokenize
from sklearn.feature_extraction.text import TfidfVectorizer
import gensim
from collections import Counter
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
import matplotlib.colors as mcolors
from sklearn.manifold import TSNE
from gensim.models import word2vec
import nltk
import warnings
warnings.filterwarnings("ignore")
# Load the Zomato Bengaluru restaurants dataset and preview the first rows.
df = pd.read_csv('zomato.csv')
pd.set_option('display.max_columns', None)  # show every column in notebook output
df.head(2)
| url | address | name | online_order | book_table | rate | votes | phone | location | rest_type | dish_liked | cuisines | approx_cost(for two people) | reviews_list | menu_item | listed_in(type) | listed_in(city) | |
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
| 0 | https://www.zomato.com/bangalore/jalsa-banasha... | 942, 21st Main Road, 2nd Stage, Banashankari, ... | Jalsa | Yes | Yes | 4.1/5 | 775 | 080 42297555\r\n+91 9743772233 | Banashankari | Casual Dining | Pasta, Lunch Buffet, Masala Papad, Paneer Laja... | North Indian, Mughlai, Chinese | 800 | [('Rated 4.0', 'RATED\n A beautiful place to ... | [] | Buffet | Banashankari |
| 1 | https://www.zomato.com/bangalore/spice-elephan... | 2nd Floor, 80 Feet Road, Near Big Bazaar, 6th ... | Spice Elephant | Yes | No | 4.1/5 | 787 | 080 41714161 | Banashankari | Casual Dining | Momos, Lunch Buffet, Chocolate Nirvana, Thai G... | Chinese, North Indian, Thai | 800 | [('Rated 4.0', 'RATED\n Had been here for din... | [] | Buffet | Banashankari |
# Report the dataset dimensions.
rows, cols = df.shape
print("dataset contains {} rows and {} columns".format(rows, cols))
dataset contains 51717 rows and 17 columns
# Address and phone carry no analytical value here; drop them.
df = df.drop(columns=['address', 'phone'])
df.columns
Index(['url', 'name', 'online_order', 'book_table', 'rate', 'votes',
'location', 'rest_type', 'dish_liked', 'cuisines',
'approx_cost(for two people)', 'reviews_list', 'menu_item',
'listed_in(type)', 'listed_in(city)'],
dtype='object')
# Shorten the clunky column names for easier access downstream.
new_names = {
    'approx_cost(for two people)': 'cost',
    'listed_in(type)': 'type',
    'listed_in(city)': 'city',
}
df = df.rename(columns=new_names)
df.info()
<class 'pandas.core.frame.DataFrame'> RangeIndex: 51717 entries, 0 to 51716 Data columns (total 15 columns): # Column Non-Null Count Dtype --- ------ -------------- ----- 0 url 51717 non-null object 1 name 51717 non-null object 2 online_order 51717 non-null object 3 book_table 51717 non-null object 4 rate 43942 non-null object 5 votes 51717 non-null int64 6 location 51696 non-null object 7 rest_type 51490 non-null object 8 dish_liked 23639 non-null object 9 cuisines 51672 non-null object 10 cost 51371 non-null object 11 reviews_list 51717 non-null object 12 menu_item 51717 non-null object 13 type 51717 non-null object 14 city 51717 non-null object dtypes: int64(1), object(14) memory usage: 5.9+ MB
Columns description
url contains the url of the restaurant in the zomato website
address contains the address of the restaurant in Bengaluru
name contains the name of the restaurant
online_order whether online ordering is available in the restaurant or not
book_table table book option available or not
rate contains the overall rating of the restaurant out of 5
votes contains total number of rating for the restaurant as of the above mentioned date
phone contains the phone number of the restaurant
location contains the neighborhood in which the restaurant is located
rest_type restaurant type
dish_liked dishes people liked in the restaurant
cuisines food styles, separated by comma
approx_cost(for two people) contains the approximate cost for meal for two people
reviews_list list of tuples containing reviews for the restaurant, each tuple
menu_item contains list of menus available in the restaurant
listed_in(type) type of meal
listed_in(city) contains the neighborhood in which the restaurant is listed
# Summarise missing values per column: absolute counts and percentages.
null_counts = df.isnull().sum()  # computed once instead of twice
Value_Missing = null_counts.sort_values(ascending=False)
percent_missing = (null_counts * 100 / len(df)).round(2).sort_values(ascending=False)
missing_data = pd.concat([Value_Missing, percent_missing], axis=1, keys=['Total', 'Percent'])
f, ax = plt.subplots(figsize=(6, 4))
plt.xticks(rotation=90)  # numeric rotation instead of the string '90'
sns.barplot(x=missing_data.index, y=missing_data['Percent'], palette='plasma')
plt.xlabel('Features', fontsize=10)
plt.ylabel('Percent of missing values', fontsize=10)
plt.title('Percent missing data by feature', fontsize=15)
missing_data.head(7)
| Total | Percent | |
|---|---|---|
| dish_liked | 28078 | 54.29 |
| rate | 7775 | 15.03 |
| cost | 346 | 0.67 |
| rest_type | 227 | 0.44 |
| cuisines | 45 | 0.09 |
| location | 21 | 0.04 |
| url | 0 | 0.00 |
# Pie chart: share of listings held by the ten most common restaurant names.
top_10 = df['name'].value_counts().head(10)
plt.figure(figsize=(15, 10))
ax = top_10.plot(kind='pie', shadow=True, startangle=140, autopct='%1.2f%%')
plt.title('Name Percentage', weight='bold')
plt.axis('equal')
plt.show()
# Share of restaurants that accept online orders (interactive pie).
order_counts = df['online_order'].value_counts()
palette = ['#FEBFB3', '#E1396C']
pie = go.Pie(
    labels=order_counts.index,
    values=order_counts,
    textinfo="value",
    marker=dict(colors=palette, line=dict(color='#000000', width=2)),
)
layout = go.Layout(title="Accepting vs not accepting online orders", width=500, height=500)
fig = go.Figure(data=[pie], layout=layout)
py.iplot(fig, filename='pie_chart_subplots')
# Share of restaurants offering table booking (interactive pie).
booking_counts = df['book_table'].value_counts()
palette = ['#96D38C', '#D0F9B1']
pie = go.Pie(
    labels=booking_counts.index,
    values=booking_counts,
    textinfo="value",
    marker=dict(colors=palette, line=dict(color='#000000', width=2)),
)
layout = go.Layout(title="Table booking", width=500, height=500)
fig = go.Figure(data=[pie], layout=layout)
py.iplot(fig, filename='pie_chart_subplots')
#reading rate of the dataset :
# Inspect the raw rating values; note the 'NEW', '-', NaN and 'x.y /5' variants
# that must be cleaned before numeric conversion.
df['rate'].unique()
array(['4.1/5', '3.8/5', '3.7/5', '3.6/5', '4.6/5', '4.0/5', '4.2/5',
'3.9/5', '3.1/5', '3.0/5', '3.2/5', '3.3/5', '2.8/5', '4.4/5',
'4.3/5', 'NEW', '2.9/5', '3.5/5', nan, '2.6/5', '3.8 /5', '3.4/5',
'4.5/5', '2.5/5', '2.7/5', '4.7/5', '2.4/5', '2.2/5', '2.3/5',
'3.4 /5', '-', '3.6 /5', '4.8/5', '3.9 /5', '4.2 /5', '4.0 /5',
'4.1 /5', '3.7 /5', '3.1 /5', '2.9 /5', '3.3 /5', '2.8 /5',
'3.5 /5', '2.7 /5', '2.5 /5', '3.2 /5', '2.6 /5', '4.5 /5',
'4.3 /5', '4.4 /5', '4.9/5', '2.1/5', '2.0/5', '1.8/5', '4.6 /5',
'4.9 /5', '3.0 /5', '4.8 /5', '2.3 /5', '4.7 /5', '2.4 /5',
'2.1 /5', '2.2 /5', '2.0 /5', '1.8 /5'], dtype=object)
# Rating vs cost-for-two, coloured by online-order availability.
cost_dist = df[['rate', 'cost', 'online_order']].dropna()
# 'NEW' and '-' are at most 3 characters long, so they map to 0.
cost_dist['rate'] = cost_dist['rate'].apply(
    lambda r: float(r.split('/')[0]) if len(r) > 3 else 0
)
# Strip the thousands separator before converting to int.
cost_dist['cost'] = cost_dist['cost'].apply(lambda c: int(c.replace(',', '')))
plt.figure(figsize=(10, 7))
sns.scatterplot(x="rate", y='cost', hue='online_order', data=cost_dist)
plt.show()
#removing '/5' from rates :
# Drop unrated rows, then normalise 'x.y/5' strings to plain floats.
df = df.loc[df.rate != 'NEW']
df = df.loc[df.rate != '-'].reset_index(drop=True)
# np.str was removed in NumPy 1.24 (AttributeError on modern installs);
# isinstance(x, str) is the correct check and leaves NaN untouched.
remove_slash = lambda x: x.replace('/5', '') if isinstance(x, str) else x
df.rate = df.rate.apply(remove_slash).str.strip().astype('float')
df['rate'].head()
0 4.1 1 4.1 2 3.8 3 3.7 4 3.8 Name: rate, dtype: float64
df['rate'].value_counts(normalize=False, sort=False, ascending=False, bins=5)
(1.796, 2.42] 187 (2.42, 3.04] 3093 (3.04, 3.66] 14320 (3.66, 4.28] 19981 (4.28, 4.9] 4084 Name: rate, dtype: int64
# Horizontal bar chart of the 20 most common restaurant types.
plt.figure(figsize=(7,7))
rest = df['rest_type'].value_counts()[:20]
# seaborn >= 0.12 removed positional x/y data arguments; pass them by keyword.
sns.barplot(x=rest, y=rest.index, palette='rainbow')
plt.title("Restaurant types")
plt.xlabel("count");
# Pie chart of the ten neighbourhoods with the most listings.
plt.figure(figsize=(12, 6))
palette = ['#00FF7F', '#63B8FF', '#FF6347', '#EE82EE', '#FFFF00',
           '#00868B', '#27408B', '#8B0000', '#33A1C9', '#FF8000']
explode = (0.1,) + (0,) * 9  # pull the largest slice out slightly
top_locations = df.location.value_counts()[:10]
ax = top_locations.plot(kind='pie', explode=explode, colors=palette,
                        shadow=True, startangle=140, autopct='%1.2f%%')
plt.title('Location Percentage', weight='bold')
plt.axis('equal')
plt.show()
# Distribution of listing types, then inspect the raw cost strings.
plt.figure(figsize = (10, 5))
Preferred_type = df['type'].value_counts()
# seaborn >= 0.12 removed positional x/y data arguments; pass them by keyword.
sns.barplot(x=Preferred_type.index, y=Preferred_type, palette="rocket");
df['cost'].unique()
array(['800', '300', '600', '700', '550', '500', '450', '650', '400',
'900', '200', '750', '150', '850', '100', '1,200', '350', '250',
'950', '1,000', '1,500', '1,300', '199', '1,100', '1,600', '230',
'130', '80', '50', '190', '1,700', nan, '180', '1,350', '2,200',
'1,400', '2,000', '1,800', '1,900', '330', '2,500', '2,100',
'3,000', '2,800', '3,400', '40', '1,250', '3,500', '4,000',
'2,400', '2,600', '120', '1,450', '469', '70', '3,200', '60',
'240', '6,000', '1,050', '2,300', '4,100', '5,000', '3,700',
'1,650', '2,700', '4,500', '140', '360'], dtype=object)
# Donut chart of the ten most frequent cost-for-two values.
plt.figure(figsize=(12, 8))
explode = (0.1,) + (0,) * 9
palette = ("#8470FF", "#FFB6C1", "#FFF68F", "#8B3A3A", "#00EE00",
           "#228B22", "#FF7D40", "#00BFFF", "#9400D3", "#FF8C00")
top_costs = df['cost'].value_counts()[:10]
ax = top_costs.plot(kind='pie', autopct='%1.2f%%', explode=explode,
                    colors=palette, fontsize=15)
# Punch a white circle through the centre to turn the pie into a donut.
centre_circle = plt.Circle((0, 0), 0.8, fc='white')
fig = plt.gcf()
fig.gca().add_artist(centre_circle)
plt.title('Average cost for two person(in %) ', weight='bold')
plt.xlabel('Cost')
plt.show()
df['cost'] = df['cost'].astype(str)
df['cost'] = df['cost'].apply(lambda x: x.replace(',','.')) #Using lambda function to replace ',' from cost
df['cost'] = df['cost'].astype(float)
df['cost'].describe()
count 49099.000000 mean 361.297400 std 231.111464 min 1.000000 25% 200.000000 50% 350.000000 75% 500.000000 max 950.000000 Name: cost, dtype: float64
# Compare vote distributions for restaurants with / without online ordering.
def _votes_box(accepts, label, colour):
    # Box trace over the 'votes' column for one online_order category.
    votes = df[df['online_order'] == accepts]['votes']
    return go.Box(y=votes, name=label, marker=dict(color=colour))

trace0 = _votes_box("Yes", "accepting online orders", 'rgb(214, 12, 140)')
trace1 = _votes_box("No", "Not accepting online orders", 'rgb(0, 128, 128)')
layout = go.Layout(
    title = "Box Plots of votes",width=800,height=500
)
data = [trace0, trace1]
fig = go.Figure(data=data, layout=layout)
py.iplot(fig)
# Geocode every unique neighbourhood via Nominatim. Prefixing with the city
# name disambiguates common street names; results are cached to CSV and the
# prefix is stripped again afterwards.
locations = pd.DataFrame({"Name": df['location'].unique()})
locations['Name'] = locations['Name'].apply(lambda x: "Bangalore " + str(x))
geolocator = Nominatim(user_agent="app")
lat_lon = []
for place in locations['Name']:
    match = geolocator.geocode(place)
    if match is None:
        lat_lon.append(np.nan)  # keep row alignment for unresolved places
    else:
        lat_lon.append((match.latitude, match.longitude))
locations['geo_loc'] = lat_lon
locations.to_csv('locations.csv', index=False)
locations["Name"] = locations['Name'].apply(lambda x: x.replace("Bangalore", "")[1:])
locations.head()
| Name | geo_loc | |
|---|---|---|
| 0 | Banashankari | (12.9152208, 77.573598) |
| 1 | Basavanagudi | (12.9417261, 77.5755021) |
| 2 | Mysore Road | (12.3669864, 76.6617233) |
| 3 | Jayanagar | (12.9292731, 77.5824229) |
| 4 | Kumaraswamy Layout | (12.9081487, 77.5553179) |
# Restaurant counts per neighbourhood, joined with their coordinates.
Rest_locations = df['location'].value_counts().reset_index()
Rest_locations.columns = ['Name', 'count']
Rest_locations = Rest_locations.merge(locations, on='Name', how="left").dropna()
Rest_locations['count'].max()
4793
def generateBaseMap(default_location=[12.97, 77.59], default_zoom_start=12):
    """Return a folium map centred on Bengaluru with a scale control."""
    return folium.Map(
        location=default_location,
        control_scale=True,
        zoom_start=default_zoom_start,
    )
# Unpack (lat, lon) tuples into separate columns, then overlay a
# count-weighted density heatmap on the base map.
coords = np.array(Rest_locations['geo_loc'])
lat, lon = zip(*coords)
Rest_locations['lat'] = lat
Rest_locations['lon'] = lon
basemap = generateBaseMap()
HeatMap(Rest_locations[['lat', 'lon', 'count']].values.tolist(),
        zoom=20, radius=15).add_to(basemap)
<folium.plugins.heat_map.HeatMap at 0x21f0203cac0>
# Render the restaurant-density heatmap in the notebook.
basemap
# Ten most common cuisine combinations.
plt.figure(figsize=(7,5))
cuisines = df['cuisines'].value_counts()[:10]
# seaborn >= 0.12 removed positional x/y data arguments; pass them by keyword.
sns.barplot(x=cuisines, y=cuisines.index)
plt.xlabel('Count')
plt.title("Most popular cuisines of Bangalore");
def produce_data(col, name):
    """Count restaurants per neighbourhood where `col` equals `name` and
    attach (lat, lon) coordinates from the module-level `locations` table."""
    grouped = df[df[col] == name].groupby(['location'], as_index=False)['url'].agg('count')
    data = pd.DataFrame(grouped)
    data.columns = ['Name', 'count']
    print(data.head())
    data = data.merge(locations, on="Name", how='left').dropna()
    # 'lan' actually holds latitude (sic) — kept for downstream compatibility.
    data['lan'], data['lon'] = zip(*data['geo_loc'].values)
    return data.drop(['geo_loc'], axis=1)
# Per-neighbourhood counts of North Indian restaurants.
North_India=produce_data('cuisines','North Indian')
Name count 0 BTM 319 1 Banashankari 28 2 Banaswadi 12 3 Bannerghatta Road 72 4 Basavanagudi 20
# Heatmap of North Indian restaurant density across the city.
basemap=generateBaseMap()
HeatMap(North_India[['lan','lon','count']].values.tolist(),zoom=20,radius=15).add_to(basemap)
basemap
# Per-neighbourhood counts of South Indian restaurants.
South_India=produce_data('cuisines','South Indian')
Name count 0 BTM 106 1 Banashankari 81 2 Banaswadi 37 3 Bannerghatta Road 44 4 Basavanagudi 87
# Heatmap of South Indian restaurant density across the city.
basemap=generateBaseMap()
HeatMap(South_India[['lan','lon','count']].values.tolist(),zoom=20,radius=15).add_to(basemap)
basemap
def produce_chains(name):
    """Return per-neighbourhood outlet counts (with coordinates) for one chain."""
    outlet_counts = df[df["name"] == name]['location'].value_counts().reset_index()
    data_chain = pd.DataFrame(outlet_counts)
    data_chain.columns = ['Name', 'count']
    data_chain = data_chain.merge(locations, on="Name", how="left").dropna()
    # 'lan' actually holds latitude (sic) — kept for downstream compatibility.
    data_chain['lan'], data_chain['lon'] = zip(*data_chain['geo_loc'].values)
    return data_chain[['Name', 'count', 'lan', 'lon']]
# For every restaurant type, pick the three names with the most listings
# (row count per (rest_type, name) pair, counted via the 'url' column).
df_1=df.groupby(['rest_type','name']).agg('count')
datas=df_1.sort_values(['url'],ascending=False).groupby(['rest_type'],
as_index=False).apply(lambda x : x.sort_values(by="url",ascending=False).head(3))['url'].reset_index().rename(columns={'url':'count'})
# NOTE(review): hard-coded Mapbox access token checked into source — treat it
# as a leaked secret: revoke it and load from an environment variable instead.
mapbox_access_token="pk.eyJ1Ijoic2hhaHVsZXMiLCJhIjoiY2p4ZTE5NGloMDc2YjNyczBhcDBnZnA5aCJ9.psBECQ2nub0o25PgHcU88w"
def produce_trace(data_chain, name):
    """Build a Scattermapbox trace whose marker size scales with outlet count."""
    data_chain['text'] = data_chain['Name'] + '<br>' + data_chain['count'].astype(str)
    marker = go.scattermapbox.Marker(size=data_chain['count'] * 4)
    return go.Scattermapbox(
        lat=data_chain['lan'],
        lon=data_chain['lon'],
        mode='markers',
        marker=marker,
        text=data_chain['text'],
        name=name,
    )
# Top three 'Quick Bites' chains by listing count.
quick=datas[datas['rest_type']=='Quick Bites']
quick
| level_0 | rest_type | name | count | |
|---|---|---|---|---|
| 179 | 78 | Quick Bites | Five Star Chicken | 69 |
| 180 | 78 | Quick Bites | Domino's Pizza | 59 |
| 181 | 78 | Quick Bites | McDonald's | 59 |
# One Mapbox trace per Quick Bites chain, all drawn on a shared layout.
data = []
for chain_name in quick['name']:
    chain_data = produce_chains(chain_name)
    data.append(produce_trace(chain_data, chain_name))
layout = go.Layout(
    title="Quick Bites Restaurant chains locations around Banglore",
    autosize=True,
    hovermode='closest',
    mapbox=dict(
        accesstoken=mapbox_access_token,
        bearing=0,
        style="streets",
        center=dict(lat=12.96, lon=77.59),
        pitch=0,
        zoom=10,
    ),
)
fig = dict(data=data, layout=layout)
py.iplot(fig, filename='Montreal Mapbox')
# Top three 'Casual Dining' chains by listing count.
casual=datas[datas['rest_type']=='Casual Dining']
casual
| level_0 | rest_type | name | count | |
|---|---|---|---|---|
| 59 | 27 | Casual Dining | Empire Restaurant | 58 |
| 60 | 27 | Casual Dining | Mani's Dum Biryani | 47 |
| 61 | 27 | Casual Dining | Chung Wah | 46 |
# One Mapbox trace per Casual Dining chain, all drawn on a shared layout.
data = []
for chain_name in casual['name']:
    chain_data = produce_chains(chain_name)
    data.append(produce_trace(chain_data, chain_name))
layout = go.Layout(
    title="Casual Dining Restaurant chains locations around Banglore",
    autosize=True,
    hovermode='closest',
    mapbox=dict(
        accesstoken=mapbox_access_token,
        bearing=0,
        style="streets",
        center=dict(lat=12.96, lon=77.59),
        pitch=0,
        zoom=10,
    ),
)
fig = dict(data=data, layout=layout)
py.iplot(fig, filename='Montreal Mapbox')
# Top three 'Cafe' chains by listing count.
cafe=datas[datas['rest_type']=='Cafe']
cafe
| level_0 | rest_type | name | count | |
|---|---|---|---|---|
| 41 | 19 | Cafe | Cafe Coffee Day | 93 |
| 42 | 19 | Cafe | Smally's Resto Cafe | 54 |
| 43 | 19 | Cafe | Mudpipe Cafe | 39 |
# One Mapbox trace per Cafe chain, all drawn on a shared layout.
data = []
for chain_name in cafe['name']:
    chain_data = produce_chains(chain_name)
    data.append(produce_trace(chain_data, chain_name))
layout = go.Layout(
    title="Cafe Restaurant chains locations around Banglore",
    autosize=True,
    hovermode='closest',
    mapbox=dict(
        accesstoken=mapbox_access_token,
        bearing=0,
        style="streets",
        center=dict(lat=12.96, lon=77.59),
        pitch=0,
        zoom=10,
    ),
)
fig = dict(data=data, layout=layout)
py.iplot(fig, filename='Montreal Mapbox')
all_ratings = []
for name,ratings in tqdm(zip(df['name'],df['reviews_list'])):
ratings = eval(ratings)
for score, doc in ratings:
if score:
score = score.strip("Rated").strip()
doc = doc.strip('RATED').strip()
score = float(score)
all_ratings.append([name,score, doc])
49440it [00:39, 1255.72it/s]
# Build a tidy ratings frame and strip non-alphanumeric characters from
# reviews. The pattern is a raw string: in a plain string '\s' is an invalid
# escape (SyntaxWarning on Python 3.12+).
rating_df = pd.DataFrame(all_ratings, columns=['name', 'rating', 'review'])
rating_df['review'] = rating_df['review'].apply(lambda x: re.sub(r'[^a-zA-Z0-9\s]', "", x))
rating_df.to_csv("Ratings.csv")
rating_df.head()
| name | rating | review | |
|---|---|---|---|
| 0 | Jalsa | 4.0 | A beautiful place to dine inThe interiors take... |
| 1 | Jalsa | 4.0 | I was here for dinner with my family on a week... |
| 2 | Jalsa | 2.0 | Its a restaurant near to Banashankari BDA Me a... |
| 3 | Jalsa | 4.0 | We went here on a weekend and one of us had th... |
| 4 | Jalsa | 5.0 | The best thing about the place is its ambiance... |
# Distribution of review ratings.
plt.figure(figsize=(7, 6))
rating_counts = rating_df['rating'].value_counts()
sns.barplot(x=rating_counts.index, y=rating_counts)
plt.xlabel("Ratings")
plt.ylabel('count')
Text(0, 0.5, 'count')
We will do topic modelling for positive and negative comments separately to understand the difference between the two types.
rating_df['sent']=rating_df['rating'].apply(lambda x: 1 if int(x)>2.5 else 0)
Now, download the NLTK resources required for text preprocessing.
#import nltk
# Fetch the stopword lists, WordNet lemmatizer data and Punkt tokenizer
# models (no-ops if already present locally).
nltk.download('stopwords')
nltk.download('wordnet')
nltk.download('punkt')
[nltk_data] Downloading package stopwords to C:\Users\Amrita [nltk_data] Sarkar\AppData\Roaming\nltk_data... [nltk_data] Package stopwords is already up-to-date! [nltk_data] Downloading package wordnet to C:\Users\Amrita [nltk_data] Sarkar\AppData\Roaming\nltk_data... [nltk_data] Package wordnet is already up-to-date! [nltk_data] Downloading package punkt to C:\Users\Amrita [nltk_data] Sarkar\AppData\Roaming\nltk_data... [nltk_data] Package punkt is already up-to-date!
True
stops=stopwords.words('english')
lem = WordNetLemmatizer()
# NOTE(review): the comprehension iterates whole REVIEW strings, not words —
# `x not in stops` compares a full review against the stopword list (never
# true, so the filter is a no-op) and lemmatize() receives entire sentences.
# Tokenise first if per-word filtering/lemmatisation is intended.
corpus=' '.join(lem.lemmatize(x) for x in rating_df[rating_df['sent']==1]['review'][:3000] if x not in stops)
tokens=word_tokenize(corpus)
vect=TfidfVectorizer()
vect_fit=vect.fit(tokens)
# Map feature indices back to vocabulary terms for the LDA model below.
id_map=dict((v,k) for k,v in vect.vocabulary_.items())
vectorized_data=vect_fit.transform(tokens)
gensim_corpus=gensim.matutils.Sparse2Corpus(vectorized_data,documents_columns=False)
# Five-topic LDA over positive reviews; fixed seed for reproducibility.
ldamodel = gensim.models.ldamodel.LdaModel(gensim_corpus,id2word=id_map,num_topics=5,random_state=34,passes=25)
# Word frequencies for the topic keywords. The Counter must be built over
# the token list: Counter(corpus) on the raw string counted single
# CHARACTERS, so counter[word] was always 0 for every multi-letter keyword.
counter = Counter(tokens)
out = []
topics = ldamodel.show_topics(formatted=False)
for topic_id, topic in topics:
    for word, weight in topic:
        out.append([word, topic_id, weight, counter[word]])
dataframe = pd.DataFrame(out, columns=['word', 'topic_id', 'importance', 'word_count'])
# Plot Word Count and Weights of Topic Keywords
# One subplot per topic: bars for raw word counts (left axis) overlaid with
# LDA keyword weights on a twin right axis, colour-coded per topic.
fig, axes = plt.subplots(2, 2, figsize=(8,6), sharey=True, dpi=160)
cols = [color for name, color in mcolors.TABLEAU_COLORS.items()]
# NOTE(review): the 2x2 grid shows only 4 of the 5 topics produced above.
for i, ax in enumerate(axes.flatten()):
    ax.bar(x='word', height="word_count", data=dataframe.loc[dataframe.topic_id==i, :], color=cols[i], width=0.3, alpha=0.3, label='Word Count')
    ax_twin = ax.twinx()  # second scale so small weights stay visible beside counts
    ax_twin.bar(x='word', height="importance", data=dataframe.loc[dataframe.topic_id==i, :], color=cols[i], width=0.2, label='Weights')
    ax.set_ylabel('Word Count', color=cols[i])
    #ax_twin.set_ylim(0, 0.030); ax.set_ylim(0, 3500)
    ax.set_title('Topic: ' + str(i), color=cols[i], fontsize=8)
    ax.tick_params(axis='y', left=False)
    ax.set_xticklabels(dataframe.loc[dataframe.topic_id==i, 'word'], rotation=30, horizontalalignment= 'right')
    ax.legend(loc='upper left'); ax_twin.legend(loc='upper right')
fig.tight_layout(w_pad=2)
fig.suptitle('Word Count and Importance of Topic Keywords', fontsize=8, y=1.05)
plt.show()
Sentiment Analysis is the process of computationally determining whether a piece of writing is positive, negative or neutral. It’s also known as opinion mining, deriving the opinion or attitude of a speaker.

For sentiment analysis on the reviews provided by users, we have to prepare our data in an appropriate format. We will map reviews to positive or negative on the basis of the rating provided by each user: a review is labelled negative if the rating is less than 2.5 and positive if the rating is greater than 2.5.
rating_df['sent']=rating_df['rating'].apply(lambda x: 1 if int(x)>2.5 else 0)
# Tokenise reviews: keep the 3000 most frequent words, then pad each
# sequence to the longest review length.
max_features=3000
tokenizer=Tokenizer(num_words=max_features,split=' ')
tokenizer.fit_on_texts(rating_df['review'].values)
X = tokenizer.texts_to_sequences(rating_df['review'].values)
X = pad_sequences(X)
embed_dim = 32
lstm_out = 32
# Embedding -> LSTM -> 2-way softmax: binary sentiment with one-hot labels,
# hence categorical_crossentropy rather than the binary variant.
model = Sequential()
model.add(Embedding(max_features, embed_dim,input_length = X.shape[1]))
#model.add(SpatialDropout1D(0.4))
model.add(LSTM(lstm_out, dropout=0.2, recurrent_dropout=0.2))
model.add(Dense(2,activation='softmax'))
model.compile(loss = 'categorical_crossentropy', optimizer='adam',metrics = ['accuracy'])
print(model.summary())
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= embedding (Embedding) (None, 194, 32) 96000 _________________________________________________________________ lstm (LSTM) (None, 32) 8320 _________________________________________________________________ dense (Dense) (None, 2) 66 ================================================================= Total params: 104,386 Trainable params: 104,386 Non-trainable params: 0 _________________________________________________________________ None
# One-hot encode the sentiment labels and hold out a third of the data.
Y = pd.get_dummies(rating_df['sent'].astype(int)).values
X_train, X_test, Y_train, Y_test = train_test_split(
    X, Y, test_size=0.33, random_state=42
)
print(X_train.shape, Y_train.shape)
print(X_test.shape, Y_test.shape)
(880394, 194) (880394, 2) (433628, 194) (433628, 2)
# Large batches keep per-epoch time manageable for this recurrent model.
batch_size = 3200
model.fit(X_train, Y_train, epochs = 2, batch_size=batch_size)
Epoch 1/2 276/276 [==============================] - 2784s 10s/step - loss: 0.2483 - accuracy: 0.8997 Epoch 2/2 276/276 [==============================] - 3207s 12s/step - loss: 0.1441 - accuracy: 0.9442
<keras.callbacks.History at 0x22f1dfaf640>
We will take 1500 rows to validate our model. We have chosen accuracy as our evaluation criterion.
# Carve the last 1500 test rows off as a validation set, then score the rest.
validation_size = 1500
X_validate, Y_validate = X_test[-validation_size:], Y_test[-validation_size:]
X_test, Y_test = X_test[:-validation_size], Y_test[:-validation_size]
score, acc = model.evaluate(X_test, Y_test, verbose=2, batch_size=batch_size)
print("score: %.2f" % (score))
print("acc: %.2f" % (acc))
136/136 - 58s - loss: 0.1343 - accuracy: 0.9502 score: 0.13 acc: 0.95